import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from transformers import GPTJForCausalLM, GPT2Tokenizer
from transformers import GPTNeoForCausalLM, GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer
from transformers import set_seed
# from transformers import GPT2Tokenizer, OPTForCausalLM
import json
import argparse
import random
import pickle
from tqdm import tqdm
from data_utils import get_masked_edits, process_mquake_remastered_cf_6334
import pandas as pd
from huggingface_hub import login
from vllm import LLM, SamplingParams


def icl_lm_eval(model, sampling_params, icl_example, target):
    """Generate a completion for one in-context prompt and echo it to stdout.

    Args:
        model: Object whose ``generate(prompts, sampling_params, use_tqdm=...)``
            returns a sequence of results exposing ``.outputs[0].text``
            (presumably a vLLM ``LLM`` instance -- confirm against the caller).
        sampling_params: Forwarded unchanged to ``model.generate``.
        icl_example: The full prompt string sent to the model.
        target: Expected answer; it is only printed for manual inspection and
            is never compared with the generated text here.

    Returns:
        The text of the first completion with surrounding whitespace stripped.
    """
    with torch.no_grad():
        # A single prompt is submitted, so only the first result is relevant.
        generations = model.generate([icl_example], sampling_params, use_tqdm=False)
        first_completion = generations[0].outputs[0]
        answer_text = first_completion.text.strip()

    # Human-readable trace of prompt, expected answer and model output.
    print('\nInput: ', icl_example, '\nTarget: ', target, '\n\nOutput: ', answer_text)
    print('==' * 50)
    return answer_text


def add_padding(tokenizer, model):
    """Register a ``[PAD]`` token and initialise its input embedding.

    The tokenizer is asked to use ``'[PAD]'`` as its pad token and the model's
    embedding matrix is resized to ``len(tokenizer)``. If that added a new
    token, the last embedding row is set to the mean of all the rows that
    precede it.

    Args:
        tokenizer: Tokenizer exposing ``add_special_tokens`` and ``__len__``.
        model: Model exposing ``resize_token_embeddings`` and an input
            embedding at either ``model.transformer.wte`` or
            ``model.model.embed_tokens``.

    Raises:
        AttributeError: If neither embedding attribute path exists on ``model``.
    """
    num_added = tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))

    # If '[PAD]' was already in the vocabulary nothing was appended, so the
    # last row belongs to some existing token and must not be overwritten.
    # Only an explicit 0 skips the update; tokenizers that return nothing keep
    # the previous always-initialise behaviour.
    if num_added == 0:
        return

    try:
        embedding_weight = model.transformer.wte.weight
    except AttributeError:
        # Fallback attribute path for models that keep the embedding under
        # ``model.model.embed_tokens`` instead.
        embedding_weight = model.model.embed_tokens.weight

    rows = embedding_weight.data
    # Average only the pre-existing rows: the freshly created last row holds
    # whatever ``resize_token_embeddings`` put there and should not influence
    # its own initial value.
    rows[-1] = rows[:-1].mean(0)


def construct_icl_examples(idx, demos, corpus_idx, demo_offset=2000):
    """Build the list of in-context demonstration strings for one test case.

    Each demonstration has the form
    ``'New Fact: <fact> <new target>\\nPrompt: <prompt> <answer>\\n\\n'`` and is
    one of three kinds, chosen by a shuffled fixed multiset of kind codes:

    * ``0`` -- the prompt repeats the new fact itself, answered with the new
      target;
    * ``1`` -- the prompt is a random entry of ``paraphrase_prompts``, answered
      with the new target;
    * ``2`` -- the prompt is a random entry of ``neighborhood_prompts``,
      answered with the original (true) target.

    Args:
        idx: Row of ``corpus_idx`` holding the demonstration ids for this case.
        demos: Indexable collection of demonstration records; each record is
            read via ``requested_rewrite`` (``prompt``, ``subject``,
            ``target_new``/``target_true`` with a ``str`` field),
            ``paraphrase_prompts`` and ``neighborhood_prompts``.
        corpus_idx: Indexable collection mapping a case index to a sliceable
            sequence of demonstration ids.
        demo_offset: Value subtracted from each demonstration id to obtain its
            position in ``demos``. Defaults to 2000, the offset that was
            previously hard-coded (presumably ids are numbered from the start
            of a larger corpus whose first 2000 entries are not in ``demos``
            -- confirm against the caller).

    Returns:
        A list of at most 32 demonstration strings, in the reverse of the
        order in which the ids appear in the selected ``corpus_idx`` row.

    Note:
        Uses the global ``random`` state (one shuffle, plus one choice per
        kind-1/kind-2 demonstration and one for the out-of-range fallback).
    """
    order = [2, 1, 2, 0, 1, 2, 2, 0, 2, 2, 1, 0, 2, 1, 2, 0, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2]
    random.shuffle(order)

    try:
        demo_ids = corpus_idx[idx]
    except IndexError:
        # No retrieval row for this case: borrow a randomly chosen row instead.
        demo_ids = corpus_idx[random.choice(range(len(corpus_idx)))]
    # Never use more demonstrations than there are kind codes.
    demo_ids = demo_ids[:len(order)]

    icl_examples = []
    for demo_id, kind in zip(demo_ids, order):
        line = demos[demo_id - demo_offset]
        rewrite = line['requested_rewrite']
        new_fact = rewrite['prompt'].format(rewrite['subject'])
        target_new = rewrite['target_new']['str']
        target_true = rewrite['target_true']['str']

        if kind == 0:
            prompt, answer = new_fact, target_new
        elif kind == 1:
            prompt, answer = random.choice(line['paraphrase_prompts']), target_new
        else:
            # kind == 2: an unrelated neighbouring prompt keeps its true answer.
            prompt, answer = random.choice(line['neighborhood_prompts']), target_true

        icl_examples.append(f'New Fact: {new_fact} {target_new}\nPrompt: {prompt} {answer}\n\n')

    # The first id of the row ends up last, i.e. closest to the final query.
    icl_examples.reverse()
    return icl_examples

    
def ike_eval_loop(mquake_dataset, edited_caseid, model, sampling_params, result_file_path,
                  masking, dataset_name, demos, corpus_idx):
    """Query the model on every question of every case and dump the raw answers.

    For each case a prompt is assembled from (a) in-context demonstrations from
    ``construct_icl_examples``, (b) a list of edited facts and (c) the question
    followed by ``' Answer concisely.'``. The generated text for each question
    is collected and the whole mapping is written as JSON at the end.

    The edited facts come from one of two places:

    * ``masking`` truthy and ``dataset_name != 'CF-6334'``: rebuilt for every
      case via ``mquake_dataset.get_edits_without_contamination``;
    * otherwise: a single global list built once from the
      ``requested_rewrite`` entries of every case whose id is in the dataset's
      rand list.

    In both cases an empty list is replaced by ``["No relevant fact."]``.

    Args:
        mquake_dataset: Object providing ``get_dataset()``, ``get_randlist()``
            and ``get_edits_without_contamination(rand_list, case)``.
        edited_caseid: Container of case ids considered edited; decides the
            ``'edited'`` flag and whether ``'new_answer'`` or ``'answer'`` is
            used as the printed target.
        model: Forwarded to ``icl_lm_eval``.
        sampling_params: Forwarded to ``icl_lm_eval``.
        result_file_path: Path of the JSON file to write.
        masking: Whether per-case fact selection is requested (see above).
        dataset_name: Dataset identifier; ``'CF-6334'`` always uses the global
            fact list.
        demos: Forwarded to ``construct_icl_examples``.
        corpus_idx: Forwarded to ``construct_icl_examples``.

    Returns:
        None. Writes ``{case_id: {'edited': bool, 'answers': [str, ...]}}`` to
        ``result_file_path``.
    """
    dataset = mquake_dataset.get_dataset()
    rand_list = mquake_dataset.get_randlist()

    per_case_facts = dataset_name != 'CF-6334' and masking

    new_facts = None
    if not per_case_facts:
        # Deduplicate with dict keys rather than a set: iterating a set of
        # strings follows hash order, which changes between interpreter runs,
        # so the prompt (and hence the evaluation) would not be reproducible.
        unique_facts = {}
        for d in dataset:
            if d['case_id'] not in rand_list:
                continue
            for r in d["requested_rewrite"]:
                unique_facts[f'{r["prompt"].format(r["subject"])} {r["target_new"]["str"]}'] = None
        new_facts = list(unique_facts)
        if not new_facts:
            new_facts = ["No relevant fact."]

    raw_answer_dict = {}
    # The position of the case in the dataset doubles as the demonstration row
    # index handed to construct_icl_examples.
    for example_idx, d in enumerate(tqdm(dataset)):
        case_id = d['case_id']
        is_edited = case_id in edited_caseid
        answers = []
        raw_answer_dict[case_id] = {'edited': is_edited, 'answers': answers}

        if per_case_facts:
            new_facts, _, _, _ = mquake_dataset.get_edits_without_contamination(rand_list, d)
            if not new_facts:
                new_facts = ["No relevant fact."]

        target = d['new_answer'] if is_edited else d['answer']

        icl_examples = construct_icl_examples(example_idx, demos, corpus_idx=corpus_idx)
        # Everything before the question is identical for all questions of a
        # case, so build it once. The fact list is concatenated directly after
        # the joined demonstrations, with no extra separator.
        prompt_prefix = '\n'.join(icl_examples) + '\n'.join(new_facts) + '\nPrompt: '
        for prompt_question in d['questions']:
            ike_example = prompt_prefix + prompt_question + ' Answer concisely.' + '\n\n'
            answers.append(icl_lm_eval(model, sampling_params, ike_example, target))

    with open(result_file_path, 'w') as fp:
        json.dump(raw_answer_dict, fp, indent=4)